{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:08] Try to use the default NATS-Bench (size) path from fast_mode=True and path=None.\n",
      "[2021-03-29 08:23:08] Create NATS-Bench (size) done with 0/32768 architectures avaliable.\n",
      "\n",
      "API create done: NATSsize(0/32768 architectures, fast_mode=True, file=None)\n",
      "\n",
      "[2021-03-29 08:23:08] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n",
      "[2021-03-29 08:23:08] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:08] Call clear_params with archive_root=/Users/xuanyidong/.torch/NATS-sss-v1_0-50262-simple and index=1234\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 83.87,\n",
      " 'test-all-time': 8.31445026397705,\n",
      " 'test-loss': 0.4872739363670349,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 85.74,\n",
      " 'train-all-time': 69.73253917694092,\n",
      " 'train-loss': 0.4183172229385376,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "from nats_bench import create\n",
    "\n",
    "# Build the NATS-Bench API for the size search space ('sss').\n",
    "# With path=None and fast_mode=True the default benchmark location is used (see the log below).\n",
    "api = create(None, 'sss', fast_mode=True, verbose=True)\n",
    "print('\\nAPI create done: {:}\\n'.format(api))\n",
    "\n",
    "# Train/test metrics of architecture #1234 on CIFAR-10.\n",
    "# As the log shows, this call uses hp=12 (12 training epochs) and is_random=True by default.\n",
    "info = api.get_more_info(1234, 'cifar10')\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:12] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=90, and is_random=True.\n",
      "[2021-03-29 08:23:12] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:12] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 90 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 89.4,\n",
      " 'test-all-time': 62.35837697982788,\n",
      " 'test-loss': 0.3388326271057129,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 95.206,\n",
      " 'train-all-time': 522.9940438270569,\n",
      " 'train-loss': 0.14320597895622253,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "# Same architecture and dataset, but with the 90-epoch training schedule (hp='90').\n",
    "info = api.get_more_info(1234, 'cifar10', hp='90')\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:15] Call the get_cost_info function with index=12, dataset=cifar10, and hp=12.\n",
      "[2021-03-29 08:23:15] Call clear_params with archive_root=/Users/xuanyidong/.torch/NATS-sss-v1_0-50262-simple and index=12\n",
      "Call query_meta_info_by_index with arch_index=12, hp=12\n",
      "[2021-03-29 08:23:15] Call _prepare_info with index=12 skip because it is in arch2infos_dict\n",
      "{'T-ori-test@epoch': 0.6709375381469727,\n",
      " 'T-ori-test@total': 8.051250457763672,\n",
      " 'T-train@epoch': 5.539922475814819,\n",
      " 'T-train@total': 66.47906970977783,\n",
      " 'flops': 7.991706,\n",
      " 'latency': 0.014862352974560795,\n",
      " 'params': 0.067378}\n"
     ]
    }
   ],
   "source": [
    "# Cost information for architecture #12 on CIFAR-10, returned as a dict with\n",
    "# 'flops', 'params', 'latency' and per-epoch / total train and test time entries.\n",
    "info = api.get_cost_info(12, 'cifar10')\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:17] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n",
      "[2021-03-29 08:23:17] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:17] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 84.28,\n",
      " 'test-all-time': 8.31445026397705,\n",
      " 'test-loss': 0.46498328766822816,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 86.004,\n",
      " 'train-all-time': 69.73253917694092,\n",
      " 'train-loss': 0.405061281375885,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "# The first query again, now passing hp='12' and is_random=True explicitly;\n",
    "# with is_random=True the reported metrics may change from call to call.\n",
    "info = api.get_more_info(1234, 'cifar10', hp='12', is_random=True)\n",
    "pprint(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-03-29 08:23:20] Call the get_more_info function with index=1234, dataset=cifar10, iepoch=None, hp=12, and is_random=True.\n",
      "[2021-03-29 08:23:20] Call query_index_by_arch with arch=1234\n",
      "[2021-03-29 08:23:20] Call _prepare_info with index=1234 skip because it is in arch2infos_dict\n",
      "{'comment': 'In this dict, train-loss/accuracy/time is the metric on the '\n",
      "            'train+valid sets of CIFAR-10. The test-loss/accuracy/time is the '\n",
      "            'performance of the CIFAR-10 test set after training on the '\n",
      "            'train+valid sets by 12 epochs. The per-time and total-time '\n",
      "            'indicate the per epoch and total time costs, respectively.',\n",
      " 'test-accuracy': 83.87,\n",
      " 'test-all-time': 8.31445026397705,\n",
      " 'test-loss': 0.4872739363670349,\n",
      " 'test-per-time': 0.6928708553314209,\n",
      " 'train-accuracy': 85.74,\n",
      " 'train-all-time': 69.73253917694092,\n",
      " 'train-loss': 0.4183172229385376,\n",
      " 'train-per-time': 5.811044931411743}\n"
     ]
    }
   ],
   "source": [
    "# Exactly the same call as the previous cell; because is_random=True,\n",
    "# the returned performance can differ from the previous output.\n",
    "info = api.get_more_info(1234, 'cifar10', hp='12', is_random=True)\n",
    "pprint(info)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
